import os
import sys

import numpy as np
np.random.seed(7) # for reproducibility
import pandas as pd

import tensorflow as tf

import keras.backend as K
from keras.models import Model, load_model

import umap
import hdbscan

import sys
sys.path.append("../Scripts/")
from IntegratedGradients import *
from util_funcs import *
from plotseqlogo import seqlogo, seqlogo_fig

#####################################################
###Functions to extract motifs from attribution maps###
#####################################################

def get_top_n_motif(scores, n, exclude_neighbour):
    """Pick the ``n`` highest-scoring positions of a 1-D score track.

    For ``n == 1`` this is simply the maximum and its index. For any other
    ``n`` the maximum is taken ``n`` times. After each pick, the positions
    ``[max_idx - exclude_neighbour, max_idx + exclude_neighbour - 1)`` are set
    to ``-inf`` on a private copy, so later picks land elsewhere. The window is
    clamped to the array and always includes ``max_idx``.

    NOTE(review): that window reaches ``exclude_neighbour`` positions before
    the pick but only ``exclude_neighbour - 2`` after it. It is kept as-is so
    interior picks are unchanged; confirm whether a symmetric window was
    intended.

    Parameters
    ----------
    scores : 1-D array-like
        One score per candidate start position (e.g. one row of ``pool_sum``).
    n : int
        Number of picks.
    exclude_neighbour : int
        Number of positions masked before each pick (see window above).

    Returns
    -------
    (np.ndarray, np.ndarray)
        Picked scores and their indices: 0-d arrays when ``n == 1``, 1-D
        arrays of length ``n`` otherwise. The input is not modified.
    """
    if n == 1:
        return np.array(np.max(scores)), np.array(np.argmax(scores))
    # Float copy: masking with -inf fails on integer arrays, and the caller's
    # array must not be modified.
    scores = np.array(scores, dtype=float)
    top_n_scores = []
    top_n_ind = []
    for _ in range(n):
        max_idx = int(np.argmax(scores))
        top_n_scores.append(scores[max_idx])
        top_n_ind.append(max_idx)
        # Clamp the start at 0. A negative start wraps to the end of the array
        # and usually gives an empty slice, which previously left max_idx
        # unmasked so the same position was picked again.
        lo = max(max_idx - exclude_neighbour, 0)
        # Always mask at least max_idx itself (matters for exclude_neighbour < 2).
        hi = max(max_idx + exclude_neighbour - 1, max_idx + 1)
        scores[lo:hi] = -np.inf
    return np.array(top_n_scores), np.array(top_n_ind)



def norm_mm(x,newmin, newmax):
    """Linearly rescale ``x`` so its minimum maps to ``newmin`` and its maximum to ``newmax``.

    Parameters
    ----------
    x : array-like
        Values to rescale (any shape).
    newmin, newmax : float
        Target range.

    Returns
    -------
    np.ndarray
        Same shape as ``x``. If ``x`` is constant (e.g. a window that became
        all zeros after clipping negative attributions), every element is
        ``newmin``. The old code divided 0/0 here and returned NaNs.
    """
    x = np.asarray(x)
    oldmin = x.min()
    oldmax = x.max()
    if oldmax == oldmin:
        # Degenerate range: no spread to rescale.
        return np.full(x.shape, newmin, dtype=float)
    return (newmax - newmin) / (oldmax - oldmin) * (x - oldmax) + newmax


def pool_sum(attr_scores,motif_size):
    """Score every length-``motif_size`` window by its mean attribution summed over channels.

    This is a stride-1, 'valid' average pool along axis 1 followed by a sum
    over axis 2, computed directly in NumPy. The previous TensorFlow 1.x
    version had two problems:

    - It created a new graph variable and ``tf.Session`` on every call.
    - It ``squeeze``-d the pooled tensor, which drops the batch axis for a
      single sequence (or the window axis for a single window), so the final
      ``np.sum(..., axis=2)`` failed.

    Parameters
    ----------
    attr_scores : array-like, shape (n_seqs, seq_len, n_channels)
        Attribution maps.
    motif_size : int
        Window length, 1 <= motif_size <= seq_len.

    Returns
    -------
    np.ndarray, shape (n_seqs, seq_len - motif_size + 1), float64
        ``out[i, j]`` is the mean over t in [j, j + motif_size) of
        ``attr_scores[i, t, :].sum()``. The TF version worked in float32, so
        values may differ from it at float32 rounding level.

    Raises
    ------
    ValueError
        If ``attr_scores`` is not 3-D or ``motif_size`` is outside [1, seq_len].
    """
    scores = np.asarray(attr_scores, dtype=float)
    if scores.ndim != 3:
        raise ValueError("attr_scores must have shape (n_seqs, seq_len, n_channels)")
    n_windows = scores.shape[1] - motif_size + 1
    if motif_size < 1 or n_windows < 1:
        raise ValueError("motif_size must be between 1 and the sequence length")
    # Window averaging and the channel sum are both linear, so sum channels first.
    per_position = scores.sum(axis=2)
    window_total = np.zeros((scores.shape[0], n_windows))
    for k in range(motif_size):
        window_total += per_position[:, k:k + n_windows]
    return window_total / motif_size


def extract_motif(attr_scores,ind,RBP_name,motif_size=6,n=1,exclude_neighbour=5,path="./"):
    """Cut the top-scoring windows out of attribution maps and pickle them as motifs.

    For each of the first ``len(ind)`` maps in ``attr_scores``:

    1. Windows are scored with ``pool_sum``.
    2. The top ``n`` start positions are chosen with ``get_top_n_motif``.
    3. Each chosen window is clipped to non-negative values and transposed to
       (channels, motif_size).
    4. The window is min-max scaled to [0, 1] with ``norm_mm`` and collected.

    The list of motifs is pickled to
    ``path + RBP_name + "top" + str(n) + "_motifs_size" + str(motif_size)``.

    Parameters
    ----------
    attr_scores : np.ndarray, (n_seqs, seq_len, n_channels)
        Attribution maps, as passed to ``pool_sum``.
    ind : sized
        Only ``len(ind)`` is used: maps ``0 .. len(ind) - 1`` are processed.
    RBP_name : str
        Used in the output file name.
    motif_size, n, exclude_neighbour :
        Window length, picks per map, and masking width for ``get_top_n_motif``.
    path : str
        Prefix prepended verbatim to the output file name.
    """
    import pickle

    window_scores = pool_sum(attr_scores, motif_size)
    collected = []
    for row in range(len(ind)):
        _, starts = get_top_n_motif(window_scores[row], n, exclude_neighbour)
        # n == 1 yields a 0-d index; atleast_1d lets both cases share one loop.
        for start in np.atleast_1d(starts):
            window = attr_scores[row][start:start + motif_size, :]
            collected.append(norm_mm(np.transpose(window.clip(min=0, max=None)), 0, 1))
    out_file = path + RBP_name + "top" + str(n) + "_motifs_size" + str(motif_size)
    with open(out_file, "wb") as handle:
        pickle.dump(collected, handle)


def get_motif_withregion(X_test_seq, X_test_reg, y_test, RBP_name, RBP_index, pred, igres, n, motif_size, path):
    """Extract and save motifs for one RBP from its positive samples predicted above 0.5.

    Steps:

    1. Select sample ``i`` when ``y_test[i, RBP_index] == 1`` and
       ``pred[i, RBP_index] > 0.50``.
    2. Print how many were selected out of the labelled positives.
    3. Call ``igres.explain`` on ``[X_test_seq[i], X_test_reg[i]]`` (with
       ``outc=RBP_index, reference=False``) for each selected sample, keeping
       element 0 of each result.
    4. Print the shape of the stacked maps and pass them to ``extract_motif``
       with ``exclude_neighbour=5``.

    NOTE(review): ``igres`` appears to be an IntegratedGradients explainer and
    element 0 the attribution for the sequence input; confirm.
    """
    selected = []
    for row in range(y_test.shape[0]):
        if y_test[row, RBP_index] == 1 and pred[row, RBP_index] > 0.50:
            selected.append(row)
    print(RBP_name + " : " + str(len(selected)) + " out of " + str(sum(y_test[:, RBP_index])))

    maps = []
    for row in selected:
        explanation = igres.explain([X_test_seq[row], X_test_reg[row]], outc=RBP_index, reference=False)
        maps.append(explanation[0])
    ex_seq = np.array(maps)
    print(ex_seq.shape)
    extract_motif(ex_seq, selected, RBP_name, motif_size=motif_size, n=n, exclude_neighbour=5, path=path)


###############################################################
###Functions to cluster motifs obtained from attribution maps###
###############################################################
def umap_motif(motifs):
    """Embed motifs in 2-D with UMAP, show a scatter plot, and return the embedding.

    Each motif is flattened with ``ravel`` into one feature vector. UMAP is
    run with ``n_components=2`` and a fixed ``random_state=42``.

    NOTE(review): ``plt`` is not imported explicitly in this file; presumably
    it arrives through one of the ``from ... import *`` lines. Confirm.

    Returns
    -------
    np.ndarray, shape (len(motifs), 2)
        The UMAP embedding.
    """
    import umap

    features = np.array([m.ravel() for m in motifs])
    embedding = umap.UMAP(n_components=2, random_state=42).fit_transform(features)

    plt.close("all")
    plt.scatter(embedding[:, 0], embedding[:, 1])
    axes = plt.gca()
    axes.set_aspect('equal', 'datalim')
    plt.title('UMAP projection of the motifs', fontsize=20)
    plt.show()
    return embedding

def cluster_motifs(RBPname, path_to_motifs,min_samples=5,min_cluster_size=10):
    """Cluster pickled motifs in UMAP space with HDBSCAN and merge each cluster.

    Steps:

    1. Load the motif list pickled at ``path_to_motifs`` (as written by
       ``extract_motif``).
    2. Embed it in 2-D with ``umap_motif``, which also shows a plot.
    3. Label the embedding with HDBSCAN and drop noise points (label -1).
    4. Plot the clustered points with matplotlib and with seaborn ``lmplot``.
    5. For each cluster (in ``pd.factorize`` code order), sum its motifs,
       divide by the TOTAL number of loaded motifs, and draw a sequence logo.

    ``RBPname`` is not used by the body.

    NOTE(review): dividing by ``len(motifs)`` rather than the cluster size
    scales each merged motif by its cluster's share of all motifs. Confirm
    that weighting is intended.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted files.

    Returns
    -------
    mergemotif : list of np.ndarray
        One merged motif per cluster.
    num_motifs : list of int
        Number of motifs in each cluster, parallel to ``mergemotif``.
    """
    import pickle
    import umap
    import hdbscan
    import seaborn as sns

    with open(path_to_motifs, "rb") as handle:
        motifs = pickle.load(handle)

    embedding = umap_motif(motifs)
    labels = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size).fit_predict(embedding)
    keep = labels >= 0  # HDBSCAN marks noise with -1
    kept_points = embedding[keep]
    kept_labels = labels[keep]

    plt.close("all")
    plt.scatter(kept_points[:, 0], kept_points[:, 1], c=kept_labels, s=2, cmap='rainbow')
    plt.gca().set_aspect('equal', 'datalim')
    plt.title('clusters of motifs', fontsize=20)
    plt.show()

    # Contiguous cluster ids 0..k-1 in order of first appearance.
    cluster_ids = pd.factorize(kept_labels)[0]
    dfmotif = pd.DataFrame({
        'x-embedded': kept_points[:, 0],
        'y-embedded': kept_points[:, 1],
        'cluster': cluster_ids,
    })
    sns.lmplot(data=dfmotif, x='x-embedded', y='y-embedded', hue='cluster', fit_reg=False, legend=True, legend_out=True)
    plt.gca().set_aspect('equal', 'datalim')
    plt.title('clusters of motifs', fontsize=20)
    plt.show()

    clustered_motifs = np.asarray(motifs)[keep]
    n_total = len(motifs)
    mergemotif = []
    num_motifs = []
    for cluster in range(len(np.unique(cluster_ids))):
        members = np.flatnonzero(cluster_ids == cluster)
        merged = np.sum(clustered_motifs[members], axis=0) / n_total
        mergemotif.append(merged)
        num_motifs.append(len(members))
        plt.close("all")
        seqlogo_fig(np.transpose(merged), vocab="RNA", figsize=(15,3), ncol=1, plot_name=["cluster" + str(cluster)])
        plt.show()
    return mergemotif, num_motifs


##########################################################
###Functions to align motifs obtained from clustering###
##########################################################


def AllMotifAlignments(motif1,motif2,min_align=4,exclude_zero=False):
    """Find the best-correlated overlap between two motifs over all start offsets.

    Every pair of start offsets ``(o1, o2)`` is tried, with
    ``o1 in [0, motif1.shape[1] - min_align]`` and
    ``o2 in [0, motif2.shape[1] - min_align]``. For each pair:

    - both motifs are trimmed to their common overlap with ``maxAligned``;
    - the score is the mean, over rows, of ``1 - scipy correlation distance``
      (Pearson r) between corresponding rows;
    - rows whose correlation is undefined (NaN, e.g. a constant row) are
      ignored.

    Parameters
    ----------
    motif1, motif2 : np.ndarray, 2-D
        Motifs laid out as (rows, positions), e.g. (4, motif_len) from
        ``extract_motif``.
    min_align : int
        Minimum overlap length.
    exclude_zero : bool
        Unused; kept for API compatibility.

    Returns
    -------
    (float, tuple)
        Best score and its ``(offset1, offset2)``. An offset pair whose row
        correlations are all NaN now scores -inf instead of NaN. Previously a
        single such pair made ``np.max``/``np.argmax`` return NaN and hid
        every valid alignment.

    Raises
    ------
    ValueError
        If either motif has fewer than ``min_align`` columns (nothing to try).
    """
    from scipy.spatial import distance
    from itertools import product

    offset = list(product(range(motif1.shape[1]-min_align+1), range(motif2.shape[1]-min_align+1)))
    if not offset:
        raise ValueError("both motifs must have at least min_align columns")
    scores = np.full(len(offset), -np.inf)
    # Constant rows make scipy's correlation divide 0/0; those NaNs are dropped below.
    with np.errstate(invalid="ignore", divide="ignore"):
        for i, (o1, o2) in enumerate(offset):
            aligned_motif1, aligned_motif2 = maxAligned(motif1, motif2, o1, o2)
            res = np.array([1 - distance.correlation(row1, row2)
                            for row1, row2 in zip(aligned_motif1, aligned_motif2)])
            res = res[~np.isnan(res)]
            if res.size:
                scores[i] = res.mean()
    best = int(np.argmax(scores))
    return scores[best], offset[best]


def maxAligned(motif1, motif2, offset1, offset2):
    """Trim two motifs to their overlap for a given pair of start offsets.

    Columns before ``offset1`` (of motif1) and ``offset2`` (of motif2) are
    dropped. Both remainders are then cut to the shorter remaining width, so
    the returned pair has equal numbers of columns.

    Returns
    -------
    tuple of (trimmed motif1, trimmed motif2)
    """
    tails = (motif1[:, offset1:], motif2[:, offset2:])
    overlap = min(t.shape[1] for t in tails)
    return tuple(t[:, :overlap] for t in tails)



def merge_aligned(motif1,motif2,min_align=4,exclude_zero=False,tr=0.8):
    """Average two motifs over their best alignment if it scores above ``tr``.

    ``AllMotifAlignments`` finds the best offsets. If the score exceeds
    ``tr``, both motifs are trimmed to their common overlap at those offsets
    (via ``maxAligned``) and averaged element-wise.

    Parameters
    ----------
    motif1, motif2 : np.ndarray, 2-D (rows, positions)
    min_align : int
        Minimum overlap, forwarded to ``AllMotifAlignments``.
    exclude_zero : bool
        Forwarded to ``AllMotifAlignments``.
    tr : float
        Minimum alignment score required to merge.

    Returns
    -------
    np.ndarray or int
        The averaged overlap, or ``0`` if the best score is not above ``tr``.
    """
    # Forward the caller's settings; these used to be hard-coded to
    # min_align=4, exclude_zero=False, silently ignoring the arguments.
    score, offset = AllMotifAlignments(motif1, motif2, min_align=min_align, exclude_zero=exclude_zero)
    if score > tr:
        aligned1, aligned2 = maxAligned(motif1, motif2, offset[0], offset[1])
        return (aligned1 + aligned2) / 2
    return 0


def pad2(array, offset, motif_len, shift):
    """Place a 4-row motif into a zero canvas of shape (4, motif_len).

    The motif's first column lands at column ``offset`` when ``shift == 0``,
    and at column ``shift - offset`` otherwise. ``motif_consensus`` passes
    ``shift = motif_len - array.shape[1]``, which right-aligns the motif and
    then moves it ``offset`` columns to the left. A placement that does not
    fit inside the canvas makes the NumPy assignment fail.
    """
    start = offset if shift == 0 else shift - offset
    canvas = np.zeros((4, motif_len))
    canvas[:, start:start + array.shape[1]] = array
    return canvas



def motif_consensus(mergemotif,num_motifs,consensus_length,tr,min_align=4):
    """Greedily merge cluster motifs into consensus motifs.

    The motif of the largest cluster (first maximum of ``num_motifs``) is the
    starting reference. Repeatedly:

    * Score the reference against every remaining motif with
      ``AllMotifAlignments``.
    * If the best score exceeds ``tr``: place both motifs at their best
      offsets on a zero canvas of width ``consensus_length`` (``pad2``),
      average them, and crop the canvas to
      ``[shift, consensus_length - shift)`` as the new reference.
    * Otherwise: store the reference as a finished consensus and restart from
      the first remaining motif. Remaining indices are kept in a Python set,
      so the order follows set iteration, as before.

    Finally every consensus is right-padded with zeros to
    ``consensus_length`` columns.

    Parameters
    ----------
    mergemotif : sequence of 2-D np.ndarray (rows, positions)
        Motifs to merge, e.g. the first output of ``cluster_motifs``.
    num_motifs : sequence of int (list or array)
        Cluster sizes parallel to ``mergemotif``.
    consensus_length : int
        Width of the alignment canvas and of the returned motifs.
    tr : float
        Minimum alignment score for two motifs to be merged.
    min_align : int
        Minimum overlap passed to ``AllMotifAlignments``. This was previously
        ignored in favour of a hard-coded 4.

    Returns
    -------
    list of np.ndarray
        Consensus motifs, each with ``consensus_length`` columns. Empty input
        gives ``[]``. A single input motif is returned padded; previously it
        was dropped and the final print raised NameError.
    """
    if len(mergemotif) == 0:
        return []
    consensus_motifs = []
    # Same tie-break as list.index(max(...)), but also accepts NumPy arrays.
    refind = int(np.argmax(num_motifs))
    refmotif = mergemotif[refind]
    motif_length = refmotif.shape[1]
    # Margin that crops a consensus_length canvas back to the reference width.
    shift = int((consensus_length - motif_length) / 2)
    index = list(set(range(len(mergemotif))) - set([refind]))
    if not index:
        # Nothing to merge with: the lone motif is its own consensus.
        consensus_motifs.append(refmotif)
    while len(index) > 0:
        # Align once per candidate and reuse the winning offsets for the merge.
        alignments = [AllMotifAlignments(refmotif, mergemotif[k], min_align=min_align, exclude_zero=False)
                      for k in index]
        scores = [score for score, _ in alignments]
        best = int(np.argmax(scores))
        if scores[best] > tr:
            ind = index[best]
            offset = alignments[best][1]
            consensus = (pad2(refmotif, offset[0], consensus_length, consensus_length - refmotif.shape[1])
                         + pad2(mergemotif[ind], offset[1], consensus_length,
                                consensus_length - mergemotif[ind].shape[1])) / 2
            refmotif = consensus[:, shift:consensus_length - shift]
            index = list(set(index) - set([ind]))
            if len(index) == 0:
                print("shape2: " + str(refmotif.shape))
                consensus_motifs.append(refmotif)
                break
        else:
            # No remaining motif matches: close this consensus and start a new one.
            consensus_motifs.append(refmotif)
            refind = index[0]
            refmotif = mergemotif[refind]
            index = list(set(index) - set([refind]))
            if len(index) == 0:
                consensus_motifs.append(refmotif)
    # Right-pad every consensus with zeros to consensus_length columns.
    for k in range(len(consensus_motifs)):
        consensus_motifs[k] = np.pad(consensus_motifs[k],
                                     [(0, 0), (0, consensus_length - consensus_motifs[k].shape[1])],
                                     mode='constant')
    print(consensus_motifs[-1].shape)
    return consensus_motifs
